import argparse
import os
import json
import time
import re
import base64
import anthropic
from typing import Any
from copy import deepcopy
from tti.environment.webgym.helper_functions import (
    # PseudoPage,
    # gitlab_get_project_memeber_role,
    llm_fuzzy_match,
    llm_ua_match,
    reddit_get_post_url,
    shopping_get_latest_order_url,
    shopping_get_sku_latest_review_author,
    shopping_get_sku_latest_review_rating,
)
from nltk.tokenize import word_tokenize
import logging
# from openai import OpenAI
# import boto3

import collections
import urllib
import html



# Template for the user message sent to the judge model. The <task>, <answer>
# and <num> placeholders are filled in with str.replace() by the prompt
# builders further down in this file.
USER_PROMPT = """TASK: <task>
Result Response: <answer>
<num> screenshots at the end: """

# Same template with an additional <reference> placeholder for a reference
# answer. NOTE(review): nothing in the visible part of this file uses it.
USER_REFERENCE_PROMPT = """TASK: <task>
Result Response: <answer>
Reference Response: <reference>
<num> screenshots at the end: """


def encode_image(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as text."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

# # WebArena evaluator functions
# def llm_fuzzy_match(pred: str, reference: str, question: str, claude_client) -> float:
#     """Check whether the prediction matches the reference with Claude3"""
#     messages: list[dict[str, Any]] = []
#     # construct the question to ask
#     message = "Help a teacher to grade the answer of a student given a question. Keep in mind that the student may use different phrasing or wording to answer the question. The goal is to evaluate whether the answer is semantically equivalent to the reference answer.\n"
#     message += f"question: {question}\n"
#     message += f"reference answer: {reference}\n"
#     message += "all the string 'N/A' that you see is a special sequence that means 'not achievable'\n"
#     message += f"student answer: {pred}\n"
#     message += "Conclude the judgement by correct/incorrect/partially correct."
#     messages = [
#         {
#             'role': 'user',
#             'content': [
#                 {'type': 'text', 'text': message}
#             ]
#         }
#     ]

#     for i in range(3):
#         try:
#             # print('Calling Claude3 API to get the auto evaluation......')
#             request_body = {
#                 "anthropic_version": "bedrock-2023-05-31",
#                 "max_tokens": 1000,
#                 "messages": messages,
#                 "temperature": 0,
#                 "top_p": 0.7,
#                 'system' : "You are a helpful assistant"
#             }
#             claude_response = claude_client.invoke_model(
#                 modelId = "anthropic.claude-3-sonnet-20240229-v1:0",
#                 body=json.dumps(request_body)
#             )
#             result = json.loads(claude_response.get("body").read())
#             break
#         except Exception as e:
#             print(e)
#             print(e.__ne__)
#             if i == 2:
#                 return 0
#             if e == 'ModelErrorException':
#                 time.sleep(10)
#             elif type(e).__ne__ == 'APIError':
#                 time.sleep(15)
#             elif type(e).__ne__ == 'InvalidRequestError':
#                 exit(0)
#             else:
#                 time.sleep(10)
#     response = result['content'][0]['text'].lower()

#     if "partially correct" in response or "incorrect" in response:
#         return 0.0
#     else:
#         assert "correct" in response
#         return 1.0


# def llm_ua_match(pred: str, reference: str, question: str, claude_client) -> float:
#     """Check whether the prediction matches the reference with Claude"""
#     messages: list[dict[str, Any]] = []
#     # construct the question to ask
#     message = ""
#     message += f"task: {question}\n"
#     message += f"actual unachievable reason: {reference}\n"
#     message += f"reported unachievable reason: {pred}\n"
#     message += (
#         "The task described above is inherently unachievable due to the reason specified under 'actual unachievable reason'. "
#         "An individual previously attempted this task and was unable to complete it. They provided a reason for their failure, "
#         "which is listed under 'reported unachievable reason'. Your role is to review both the actual and reported reasons. "
#         "Determine if the reported reason aligns with the actual reason, even if implicitly. "
#         "If the stated reason is in line with the actual reason, respond with 'same'. Otherwise, respond with 'different'."
#     )
#     messages = [
#         {
#             'role': 'user',
#             'content': [
#                 {'type': 'text', 'text': message}
#             ]
#         }
#     ]

#     while True:
#         try:
#             # print('Calling Claude3 API to get the auto evaluation......')
#             request_body = {
#                 "anthropic_version": "bedrock-2023-05-31",
#                 "max_tokens": 1000,
#                 "messages": messages,
#                 "temperature": 0,
#                 "top_p": 0.7,
#                 'system' : "You are a helpful assistant"
#             }
#             claude_response = claude_client.invoke_model(
#                 modelId = "anthropic.claude-3-sonnet-20240229-v1:0",
#                 body=json.dumps(request_body)
#             )
#             result = json.loads(claude_response.get("body").read())
#             break
#         except Exception as e:
#             print(e)
#             print(e.__ne__)
#             if e == 'ModelErrorException':
#                 time.sleep(10)
#             elif type(e).__ne__ == 'APIError':
#                 time.sleep(15)
#             elif type(e).__ne__ == 'InvalidRequestError':
#                 exit(0)
#             else:
#                 time.sleep(10)
#     response = result['content'][0]['text'].lower()
    
    
#     if "different" in response:
#         return 0.0
#     else:
#         assert "same" in response
#         return 1.0


class StringEvaluator():
    """Score a textual answer against the reference answers of a task.

    The strategies are selected by the keys of
    ``eval_config["reference_answers"]``:
    exact_match: the cleaned answer equals the cleaned reference
    must_include: every listed phrase occurs in the cleaned answer
    fuzzy_match: an LLM judge decides whether answer and reference agree
    """

    def __init__(self):
        # Passed through to the matcher wrappers below, which currently ignore it.
        self.client = None

    def clean_answer(self, answer: str) -> str:
        """Strip whitespace, drop one pair of matching outer quotes, lowercase."""
        trimmed = answer.strip()
        for quote in ("'", '"'):
            if trimmed.startswith(quote) and trimmed.endswith(quote):
                trimmed = trimmed[1:-1]
                break
        return trimmed.lower()

    def exact_match(self, ref: str, pred: str) -> float:
        """Return 1.0 when *pred* and *ref* are equal after cleaning, else 0.0."""
        normalized_pred = self.clean_answer(pred)
        normalized_ref = self.clean_answer(ref)
        return float(normalized_pred == normalized_ref)

    def must_include(self, ref: str, pred: str, tokenize: bool = False) -> float:
        """Return 1.0 when the cleaned *ref* occurs in the cleaned *pred*.

        With ``tokenize=True`` and a one-character reference that is a single
        token, membership is tested against the tokens of *pred* instead of
        its raw text, which avoids false positives such as "0" inside "10".
        """
        needle = self.clean_answer(ref)
        haystack = self.clean_answer(pred)
        token_level = (
            tokenize
            and len(needle) == 1
            and len(word_tokenize(needle)) == 1
        )
        if not token_level:
            return float(needle in haystack)
        return float(needle in word_tokenize(haystack))

    def fuzzy_match(self, ref: str, pred: str, intent: str, client) -> float:
        """Delegate to ``llm_fuzzy_match``; *client* is accepted but not forwarded."""
        return llm_fuzzy_match(pred, ref, intent)

    def ua_match(self, ref: str, pred: str, intent: str, client) -> float:
        """Delegate to ``llm_ua_match``; *client* is accepted but not forwarded."""
        return llm_ua_match(pred, ref, intent)

    def __call__(
        self,
        task_content,
        answer,
        eval_config,
        driver
    ) -> float:
        """Multiply the scores of every strategy listed in the config.

        *driver* is unused here; it is part of the signature shared with the
        other evaluators in this file.
        """
        pred = self.clean_answer(answer)
        score = 1.0

        for approach, value in eval_config["reference_answers"].items():
            if approach == "exact_match":
                score *= self.exact_match(value, pred)

            elif approach == "must_include":
                assert isinstance(value, list)
                single_reference = len(value) == 1
                for must_value in value:
                    score *= self.must_include(
                        ref=must_value,
                        pred=pred,
                        tokenize=single_reference,
                    )

            elif approach == "fuzzy_match":
                if value == "N/A":
                    # First accept a literal "N/A" answer ...
                    score *= self.exact_match(ref=value, pred=pred)
                    # ... otherwise judge whether the stated reason for the
                    # task being unachievable agrees with the annotated one.
                    # Note this assignment replaces the running score.
                    if score != 1:
                        score = 1.0 * self.ua_match(
                            intent=task_content,
                            ref=eval_config["string_note"],
                            pred=pred,
                            client=self.client,
                        )
                else:
                    assert isinstance(value, list)
                    for reference in value:
                        score *= self.fuzzy_match(
                            ref=reference,
                            pred=pred,
                            intent=task_content,
                            client=self.client,
                        )
        return score
        
def replace_ip_and_port(target_url, url_to_modify):
    """Point every ``http://<ipv4>:<port>`` prefix in *url_to_modify* at *target_url*'s host.

    Args:
        target_url: URL whose ``http://<ipv4>:<port>`` prefix is the one to
            copy (callers in this file pass the browser's current URL).
        url_to_modify: Text containing zero or more ``http://<ipv4>:<port>``
            prefixes; may hold several URLs or a non-URL value such as
            ``"last"``.

    Returns:
        *url_to_modify* with each matching prefix replaced. When *target_url*
        has no such prefix there is nothing to copy, so *url_to_modify* is
        returned unchanged. The previous behaviour was to return an English
        error sentence, which callers then treated as a URL.
    """
    ip_port_pattern = r'http://([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+):([0-9]+)'
    target_match = re.search(ip_port_pattern, target_url)

    if target_match is None:
        return url_to_modify

    # group(0) is exactly "http://<ip>:<port>"; it holds no backslashes, so it
    # is safe to use directly as the replacement template.
    return re.sub(ip_port_pattern, target_match.group(0), url_to_modify)

class URLEvaluator():
    """Check whether the browser's current URL matches the task's reference URL(s)."""

    def __call__(
            self,
            task_content,
            answer,
            eval_config,
            driver
            ):
        """Score ``driver.current_url`` against ``eval_config["reference_url"]``.

        The reference may list alternatives separated by ``" |OR| "``; its
        host is first rewritten by ``replace_ip_and_port`` using the current
        URL. Only the ``"GOLD in PRED"`` rule (the default for ``url_note``)
        is implemented:

        * base score: 1.0 if any reference ``netloc + path`` is a substring
          of the predicted ``netloc + path``;
        * query score: for every query key seen in any reference URL, the
          prediction must carry at least one of the reference values.

        *task_content* and *answer* are unused; they belong to the signature
        shared by the evaluators in this file.

        Returns:
            ``base_score * query_score`` as a float (0.0 or 1.0).

        Raises:
            ValueError: if ``url_note`` names an unknown matching rule.
        """
        # The file only does ``import urllib``, which does not guarantee that
        # the ``urllib.parse`` submodule is loaded; import what is needed here.
        from urllib.parse import parse_qs, urlparse

        def clean_url(url: str) -> str:
            """Coerce to ``str`` and drop trailing slashes."""
            return str(url).rstrip("/")

        def parse_url(url: str) -> tuple[str, dict[str, list[str]]]:
            """Split a URL into ``netloc + path`` and its parsed query mapping."""
            parsed_url = urlparse(url)
            base_path = parsed_url.netloc + parsed_url.path
            return base_path, parse_qs(parsed_url.query)

        def parse_urls(
            urls: list[str],
        ) -> tuple[list[str], dict[str, set[str]]]:
            """Parse several URLs, pooling the query values of all of them per key."""
            base_paths = []
            queries = collections.defaultdict(set)
            for url in urls:
                base_path, query = parse_url(url)
                base_paths.append(base_path)
                for key, values in query.items():
                    queries[key].update(values)
            return base_paths, queries

        pred = clean_url(driver.current_url)
        ref_urls = replace_ip_and_port(driver.current_url, eval_config["reference_url"]).split(" |OR| ")
        ref_urls = [clean_url(url) for url in ref_urls]
        matching_rule = eval_config.get("url_note", "GOLD in PRED")

        if matching_rule != "GOLD in PRED":
            raise ValueError(f"Unknown matching rule: {matching_rule}")

        ref_base_paths, ref_queries = parse_urls(ref_urls)
        pred_base_path, pred_query = parse_url(pred)

        # ``pred_base_path`` is a string, so ``in`` is a substring test:
        # the gold path only has to be contained in the predicted one.
        base_score = float(
            any(ref_base_path in pred_base_path for ref_base_path in ref_base_paths)
        )

        query_score = 1.0
        for key, possible_values in ref_queries.items():
            query_score *= float(
                any(
                    possible_ref_value in pred_query.get(key, [])
                    for possible_ref_value in possible_values
                )
            )

        return base_score * query_score

class HTMLContentEvaluator():
    """Check whether required contents appear in pages reached through the driver."""

    def clean_answer(self, answer: str) -> str:
        """Strip whitespace, drop one pair of matching outer quotes, lowercase."""
        answer = answer.strip()
        if answer.startswith("'") and answer.endswith("'"):
            answer = answer[1:-1]
        elif answer.startswith('"') and answer.endswith('"'):
            answer = answer[1:-1]
        return answer.lower()

    def exact_match(self, ref: str, pred: str) -> float:
        """Return 1.0 when *pred* and *ref* are equal after cleaning, else 0.0."""
        return float(
            self.clean_answer(pred)
            == self.clean_answer(ref)
        )

    def must_include(self, ref: str, pred: str, tokenize: bool = False) -> float:
        """Return 1.0 when the cleaned *ref* occurs in the cleaned *pred*.

        With ``tokenize=True`` and a one-character, single-token reference the
        test is made against the tokens of *pred* rather than its raw text, to
        prevent false positives (e.g. "0").
        """
        clean_ref = self.clean_answer(ref)
        clean_pred = self.clean_answer(pred)
        if (
            tokenize
            and len(clean_ref) == 1
            and len(word_tokenize(clean_ref)) == 1
        ):
            tok_pred = word_tokenize(clean_pred)
            return float(clean_ref in tok_pred)
        return float(clean_ref in clean_pred)

    def __call__(
            self,
            task_content,
            answer,
            eval_config,
            driver
            ):
        """Score every target in ``eval_config["program_html"]`` and multiply the results.

        For each target this may navigate *driver* to the target URL (then
        sleeps 3 seconds), extracts text according to the target's
        ``locator``, HTML-unescapes it and compares it with
        ``required_contents`` (``exact_match`` or ``must_include``; a
        ``must_include`` entry may list alternatives joined by ``" |OR| "``).

        *task_content* and *answer* are unused; they belong to the signature
        shared by the evaluators in this file.

        Returns:
            Product of the per-target scores as a float; 1.0 for no targets.

        Raises:
            ValueError: for an unrecognised locator or ``required_contents`` key.
        """
        targets = eval_config["program_html"]
        score = 1.0
        for target in targets:
            target_url = replace_ip_and_port(driver.current_url, target["url"])
            if target_url.startswith("func"):
                func = target_url.split("func:")[1]
                func = func.replace("__last_url__", driver.current_url)
                # SECURITY: eval() runs text from the task config as Python.
                # Only use with trusted config files.
                target_url = eval(func)

            locator = target["locator"]

            # "last" means: evaluate the page the driver is already on.
            if target_url != "last":
                driver.get(target_url)
                time.sleep(3)  # give the page time to load

            if not locator.strip():
                selected_element = driver.page_source
            elif locator.startswith(("document.", "[...document.")):
                if "prep_actions" in target:
                    # Best effort: a failing prep action aborts the remaining
                    # ones but must not fail the evaluation itself.
                    try:
                        for prep_action in target["prep_actions"]:
                            driver.execute_script(f"return {prep_action}")
                    except Exception as exc:
                        logging.debug("prep_actions failed for locator %r: %s", locator, exc)
                try:
                    js_result = driver.execute_script(f"return {locator}")
                except Exception:
                    selected_element = ""
                else:
                    # A null/undefined JS result must count as "nothing
                    # selected"; str(None) would yield the text "None".
                    selected_element = "" if js_result is None else str(js_result)
            elif locator.startswith("func:"):
                func = locator.split("func:")[1]
                func = func.replace("__page__", "page")
                # SECURITY: eval() on config text, see above.
                # NOTE(review): no name ``page`` is defined in the visible
                # code, so an expression that references it would raise
                # NameError here -- confirm what the helpers expect.
                selected_element = eval(func)
            else:
                raise ValueError(f"[Unknown locator: {locator}]")

            selected_element = html.unescape(selected_element)

            required = target["required_contents"]
            if "exact_match" in required:
                cur_score = self.exact_match(ref=required["exact_match"], pred=selected_element)
                score *= float(cur_score)
            elif "must_include" in required:
                required_contents = required["must_include"]
                assert isinstance(required_contents, list)
                for content in required_contents:
                    alternatives = content.split(" |OR| ")
                    matched = any(
                        self.must_include(ref=alternative, pred=selected_element, tokenize=False)
                        for alternative in alternatives
                    )
                    score *= float(matched)
            else:
                raise ValueError(f"Unknown required_contents: {target['required_contents'].keys()}")
        return score

import threading
import concurrent

# Modification for webarena_batch_eval in tti/environment/webgym/utils_eval.py
def webarena_batch_eval(trajectories, batch_obs, batch_eval_info, batch_env):
    """Evaluate a batch of trajectories in parallel and store rewards in place.

    A trajectory at index ``i`` is evaluated only when its observation,
    eval info and environment task are all present, the trajectory is
    non-empty and ``batch_eval_info[i]`` contains an ``'answer'`` key.
    Each eligible trajectory is scored by ``webarena_eval`` on a worker
    thread, and the result is written to ``trajectories[i][-1]['reward']``.

    Args:
        trajectories: Per-environment sequences of step dicts (mutated).
        batch_obs: Per-environment observations; ``None`` marks a failed slot.
        batch_eval_info: Per-environment mappings; may hold ``'answer'``.
        batch_env: Object exposing ``envs`` (each with ``task`` and
            ``driver_task``) and ``verbose``.

    Returns:
        The same ``trajectories`` object that was passed in.
    """
    # ``import concurrent`` alone does not load the ``futures`` submodule,
    # so import the executor explicitly instead of relying on another module
    # having imported it as a side effect.
    from concurrent.futures import ThreadPoolExecutor

    job_args = []
    valid_indices = []  # original trajectory index of each submitted job

    for i, trajectory in enumerate(trajectories):
        # Skip failed slots and empty trajectories.
        if (
            batch_obs[i] is None
            or batch_eval_info[i] is None
            or batch_env.envs[i].task is None
            or len(trajectory) == 0
        ):
            continue

        try:
            eval_info = batch_eval_info[i]
            # Only trajectories that ended with an ANSWER are evaluated.
            if 'answer' not in eval_info:
                continue

            env = batch_env.envs[i]
            args = (
                env.task['ques'],
                eval_info.get('answer', 'N/A'),
                env.task['eval'],
                env.driver_task,
                batch_env.verbose,
            )
            job_args.append(args)
            valid_indices.append(i)
        except Exception as e:
            if batch_env.verbose:
                logging.error(f"Error preparing evaluation for trajectory {i}: {str(e)}")

    if not job_args:
        return trajectories

    with ThreadPoolExecutor() as executor:
        jobs = [executor.submit(webarena_eval, *jargs) for jargs in job_args]
        rewards = [job.result() for job in jobs]

    # Map results back through the recorded indices so skipped trajectories
    # do not shift the assignment.
    for traj_idx, reward in zip(valid_indices, rewards):
        if reward is not None and len(trajectories[traj_idx]) > 0:
            trajectories[traj_idx][-1]['reward'] = reward

    return trajectories

def webarena_eval(task_content, answer, eval_config, driver, verbose=False):
    """Score one trajectory with the evaluators named in ``eval_config["eval_types"]``.

    All evaluators are instantiated before any of them runs, then their
    scores are multiplied together. Any exception raised along the way
    (including an unsupported ``eval_type``) yields a score of 0.0 instead of
    propagating. Outcome messages are logged only when *verbose* is true.
    """
    try:
        evaluators = []
        for eval_type in eval_config["eval_types"]:
            if eval_type == "string_match":
                evaluators.append(StringEvaluator())
            elif eval_type == "url_match":
                evaluators.append(URLEvaluator())
            elif eval_type == "program_html":
                evaluators.append(HTMLContentEvaluator())
            else:
                raise ValueError(f"eval_type {eval_type} is not supported")

        score = 1.0
        for evaluate in evaluators:
            score *= evaluate(task_content, answer, eval_config, driver)

        if verbose:
            logging.info(f"[WEBARENA EVAL SUCCEED] Task: {task_content} Answer: {answer} Config: {eval_config} Result: {score}")
    except Exception as e:
        # Any failure counts as an unsuccessful trajectory.
        score = 0.0
        if verbose:
            logging.info(f"[WEBARENA EVAL FAIL] Task: {task_content} Answer: {answer} Config: {eval_config} Result: {score}")
            logging.error(f"Evaluation error: {str(e)}")
    return score

def auto_eval_by_claude_console(it_messages, process_dir, img_path, anthropic_api_key, api_model, img_num, task, evaluator_prompt, thinking_budget=1024):
    """Ask a Claude model to judge whether the agent's final answer solves the task.

    The task text is taken from the first message, the answer from the last
    one, and the most recent ``screenshot<N>.png`` files in *process_dir* are
    attached as images.

    Args:
        it_messages: Conversation messages (dicts with a ``"content"`` entry).
        process_dir: Directory holding ``screenshot<N>.png`` files.
        img_path: Unused; kept for signature compatibility.
        anthropic_api_key: API key for the Anthropic client.
        api_model: Model identifier passed to the Messages API.
        img_num: Maximum number of trailing screenshots to attach.
        task: Unused; kept for signature compatibility.
        evaluator_prompt: System prompt for the judge.
        thinking_budget: Token budget for extended thinking. A falsy value
            disables thinking and requests ``temperature=0`` instead. Values
            below 1024 are raised to 1024, the minimum the Anthropic
            extended-thinking documentation allows.

    Returns:
        ``None`` when *it_messages* is empty; ``0`` when the last message has
        no ``"Action: ANSWER"``; otherwise ``(verdict, judge_text)``, where
        ``verdict`` is 1 if the judge text contains ``"SUCCESS"`` but not
        ``"NOT SUCCESS"``, else 0.
    """
    if len(it_messages) == 0:
        return None

    # Extract the task content.
    task_info = it_messages[0]["content"]
    if isinstance(task_info, list):
        task_info = task_info[0]["text"]
    assert 'Now given a task' in task_info, "Task content not found"
    pattern = r"Now given a task:(.+?)Please interact with"
    matches = re.search(pattern, task_info)
    task_content = matches.group(1).strip() if matches else ""

    # Extract the answer content.
    ans_info = it_messages[-1]["content"]
    if 'Action: ANSWER' not in ans_info:
        return 0
    pattern_ans = r"ANSWER[; ]+\[?(.[^\]]*)\]?"
    matches_ans = re.search(pattern_ans, ans_info)
    answer_content = matches_ans.group(1).strip() if matches_ans else ""

    # Gather the most recent screenshots. Match the exact naming scheme so
    # that unrelated .png files in the directory do not crash int().
    shot_name = re.compile(r"screenshot(\d+)\.png")
    screenshots = sorted(
        int(found.group(1))
        for found in map(shot_name.fullmatch, os.listdir(process_dir))
        if found
    )
    # ``[-0:]`` would select everything, so treat a non-positive count as none.
    screenshots = screenshots[-img_num:] if img_num > 0 else []

    whole_content_img = []
    for screenshot_id in screenshots:
        cur_img_path = os.path.join(process_dir, f'screenshot{screenshot_id}.png')
        b64_img = encode_image(cur_img_path)
        whole_content_img.append({
            'type': 'image',
            'source': {'type': 'base64', 'media_type': 'image/png', 'data': b64_img}
        })

    # Prepare the full prompt for evaluation.
    user_prompt_tmp = USER_PROMPT.replace('<task>', task_content)
    user_prompt_tmp = user_prompt_tmp.replace('<answer>', answer_content)
    user_prompt_tmp = user_prompt_tmp.replace('<num>', str(img_num))

    messages = [
        {
            'role': 'user',
            'content': (
                [{'type': 'text', 'text': user_prompt_tmp}]
                + whole_content_img +
                [{'type': 'text', 'text': "Your verdict:\n"}]
            )
        }
    ]

    answer_max_tokens = 1000
    request_kwargs = {
        'model': api_model,
        'system': evaluator_prompt,
        'messages': messages,
    }
    if thinking_budget:
        # Per the Anthropic extended-thinking docs: the budget must be at
        # least 1024 and smaller than max_tokens, and a custom temperature
        # cannot be combined with thinking.
        budget = max(1024, thinking_budget)
        request_kwargs['thinking'] = {"type": "enabled", "budget_tokens": budget}
        request_kwargs['max_tokens'] = budget + answer_max_tokens
    else:
        request_kwargs['max_tokens'] = answer_max_tokens
        request_kwargs['temperature'] = 0  # deterministic output

    client = anthropic.Anthropic(api_key=anthropic_api_key)
    response = client.messages.create(**request_kwargs)

    # The SDK returns a Message object, not a dict. With thinking enabled the
    # first content block is a thinking block, so collect text blocks by type.
    claude_3_res = "".join(
        block.text
        for block in response.content
        if getattr(block, "type", None) == "text"
    )

    # Determine the evaluation result based on the output text.
    auto_eval_res = 1 if ("SUCCESS" in claude_3_res and "NOT SUCCESS" not in claude_3_res) else 0
    return auto_eval_res, claude_3_res

def get_eval_prompt_gemma3(it_messages, process_dir, img_num):
    """Build the judge messages (task, answer, recent screenshots) without calling a model.

    Args:
        it_messages: Conversation messages (dicts with a ``"content"`` entry).
        process_dir: Directory holding ``screenshot<N>.png`` files.
        img_num: Maximum number of trailing screenshots to attach.

    Returns:
        ``""`` when *it_messages* is empty; ``0`` when the last message has
        no ``"Action: ANSWER"``; otherwise a one-element list holding the
        user message, whose content is the prompt text, the base64-encoded
        screenshots and a closing ``"Your verdict:\\n"`` text part.
    """
    if len(it_messages) == 0:
        print("ERROR: No messages found for evaluation")
        return ""

    # Extract the task content.
    task_info = it_messages[0]["content"]
    if isinstance(task_info, list):
        task_info = task_info[0]["text"]
    assert 'Now given a task' in task_info, "Task content not found"

    pattern = r"Now given a task:(.+?)Please interact with"
    matches = re.search(pattern, task_info)
    task_content = matches.group(1).strip() if matches else ""

    # Extract the answer content.
    ans_info = it_messages[-1]["content"]
    if 'Action: ANSWER' not in ans_info:
        print("ERROR: No ANSWER found for evaluation")
        return 0
    pattern_ans = r"ANSWER[; ]+\[?(.[^\]]*)\]?"
    matches_ans = re.search(pattern_ans, ans_info)
    answer_content = matches_ans.group(1).strip() if matches_ans else ""

    # Gather the most recent screenshots. Match the exact naming scheme so
    # that unrelated .png files in the directory do not crash int().
    shot_name = re.compile(r"screenshot(\d+)\.png")
    screenshots = sorted(
        int(found.group(1))
        for found in map(shot_name.fullmatch, os.listdir(process_dir))
        if found
    )
    # ``[-0:]`` would select everything, so treat a non-positive count as none.
    screenshots = screenshots[-img_num:] if img_num > 0 else []

    whole_content_img = []
    for screenshot_id in screenshots:
        cur_img_path = os.path.join(process_dir, f'screenshot{screenshot_id}.png')
        b64_img = encode_image(cur_img_path)
        whole_content_img.append({
            'type': 'image',
            'source': {'type': 'base64', 'media_type': 'image/png', 'data': b64_img}
        })

    # Prepare the full prompt for evaluation.
    user_prompt_tmp = USER_PROMPT.replace('<task>', task_content)
    user_prompt_tmp = user_prompt_tmp.replace('<answer>', answer_content)
    user_prompt_tmp = user_prompt_tmp.replace('<num>', str(img_num))

    messages = [
        {
            'role': 'user',
            'content': (
                [{'type': 'text', 'text': user_prompt_tmp}]
                + whole_content_img +
                [{'type': 'text', 'text': "Your verdict:\n"}]
            )
        }
    ]

    return messages


# def evaluate(agent, messages, process_dir, img_path, task, driver, use_webarena_evaluator=False,
#              evaluator_model="claude-3-7-sonnet-20250219", evaluator_prompt=None,
#              evaluator_imgs=3, anthropic_api_key=""):
#     # # If using the webarena evaluator (if that path is unchanged), delegate to it.
#     if task.get('eval') is not None and task['eval'] is not None and use_webarena_evaluator:
#         # Assuming webarena_eval is defined elsewhere.
#         return webarena_eval(messages, task['eval'], driver, None)
#     else:
#         return auto_eval_by_gemma3(agent, messages, process_dir, img_path,
#                                 anthropic_api_key, api_model=evaluator_model,
#                                 img_num=evaluator_imgs, task=task,
#                                 evaluator_prompt=evaluator_prompt)
